Pytorch中DataLoader类的多线程实现方法分析 |
您所在的位置:网站首页 › dataloader workers › Pytorch中DataLoader类的多线程实现方法分析 |
Pytorch中DataLoader类的多线程实现⽅法分析 之前在改⾃定义的DataSet的时候,由于在getitem()⾥⾯写了太多操作,导致训练过程贼慢,于是考虑⽤多线程优化⼀下。查阅⼀些资料发现 pytorch在DataLoader⾥⾯就有多线程的实现,只要在定义的时候将num_worker设置成⼤于0就可以了。遂想要探索⼀下pytorch具体的实现 ⽅法。 ⾸先找到迭代器: def __iter__(self): return _DataLoaderIter(self) 初始化: def __init__(self, loader): self.dataset = loader.dataset self.collate_fn = loader.collate_fn self.batch_sampler = loader.batch_sampler self.num_workers = loader.num_workers self.pin_memory = loader.pin_memory and torch.cuda.is_available() self.timeout = loader.timeout self.done_event = threading.Event() self.sample_iter = iter(self.batch_sampler) base_seed = torch.LongTensor(1).random_().item() collate_fn:将数据整合成⼀个batch返回的⽅法,⽤户可以⾃定义 batch_sampler:⾃定义如何取样 pin_menory:是否将数据集拷贝到显卡上 done_event:事件管理标志 sample_iter:迭代器,所以batch_sampler应该类似于⽤户⾃定义的⼀个数据的列表,⽤来⽣成可迭代对象sample_iter。 下⾯是与多线程有关的⼀些定义: if self.num_workers > 0: self.worker_init_fn = loader.worker_init_fn self.index_queues = [multiprocessing.Queue() for _ in range(self.num_workers)] self.worker_queue_idx = 0 self.worker_result_queue = multiprocessing.SimpleQueue() self.batches_outstanding = 0 self.worker_pids_set = False self.shutdown = False self.send_idx = 0 self.rcvd_idx = 0 self.reorder_dict = {} self.workers = [ multiprocessing.Process( target=_worker_loop, args=(self.dataset, self.index_queues[i], self.worker_result_queue, self.collate_fn, base_seed + i, self.worker_init_fn, i)) for i in range(self.num_workers)] worker_init_fn:⽤户定义的每个worker初始化的时候需要执⾏的函数。 index_queues:这⾥⽤到了multiprocessing,pytorch的multiprocessing是对python原⽣的multiprocessing的⼀个封装,不过好像基本 没什么变化。这⾥定义⼀个队列,multiprocessing的Queue类(这个Queue的⽗类)提供了put()和get()⽅法,⽤来向队列中增加线程和移除 线程并返回结果。Pytorch的封装另外提供了send()和recv()⽅法,⽤来接收和读取缓存,具体实现和作⽤这⾥暂且按下不表。通过阅读后⾯的代 |
CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3 |